import math
import torch
from torch.distributions.normal import Normal
import torch.nn.functional as F

# DDIM (denoising diffusion implicit model) sampling utilities
def compute_alpha(beta, t):
    """Look up the cumulative product of ``(1 - beta)`` at timesteps ``t``.

    A zero beta is prepended to the schedule before the cumulative product, so
    the lookup is shifted by one: ``t = -1`` yields 1.0 and ``t = k`` yields
    ``prod_{s<=k} (1 - beta[s])``.

    Args:
        beta: 1-D beta schedule.
        t: Integer index tensor of timesteps (``-1`` allowed).

    Returns:
        Tensor of shape ``(len(t), 1)``.
    """
    zero = torch.zeros(1).to(beta.device)
    alphas_cumprod = torch.cumprod(1 - torch.cat([zero, beta], dim=0), dim=0)
    # The +1 shift compensates for the prepended zero-beta entry.
    selected = alphas_cumprod.index_select(0, t + 1)
    return selected.view(-1, 1)

def _ddim_noise_scales(at, at_next, eta):
    """Return the DDIM step coefficients ``(c1, c2)``.

    ``c1`` multiplies freshly sampled noise and ``c2`` multiplies the model's
    predicted noise when stepping from cumulative alpha ``at`` to ``at_next``.
    With ``eta == 0`` ``c1`` is zero and the step is deterministic.
    """
    c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
    c2 = ((1 - at_next) - c1 ** 2).sqrt()
    return c1, c2


def generalized_steps(seq, model, b, prototype, eta, variance, last=True, **kwargs):
    """Run DDIM-style reverse sampling jointly on a sample and its projected feature.

    The initial sample ``z`` is drawn as ``torch.normal(0, std=variance)``, and its
    feature is ``z @ prototype``. At every step the model is called on the current
    feature and timestep. Its ``'noise'`` output is split in half along dim 1:
    the first half is the noise estimate for the sample, and the second half
    (right-multiplied by ``prototype``) is the noise estimate for the feature.
    After each update the sample is passed through a row-wise softmax.

    Args:
        seq: Timestep indices, iterated in reverse. Each step ``i`` is paired
            with the element before it (``-1`` for the first one). Presumably
            increasing; TODO confirm with callers.
        model: Module called as ``model(x_feature, t)`` and returning a dict with
            a ``'noise'`` tensor. The device of its first parameter is used for
            all computation.
        b: 1-D beta schedule passed to ``compute_alpha``.
        prototype: Matrix right-multiplied onto samples/noise to obtain features.
        eta: Scale of the stochastic term; ``0`` gives deterministic DDIM.
        variance: Tensor used as the *standard deviation* of the Gaussian noise.
            Its shape determines the sample shape.
        last: If True, return only the final sample and feature.
        **kwargs: Accepted and ignored.

    Returns:
        If ``last``: ``(x_final, x_feature_final)``. Otherwise
        ``(xs, x0_preds, xs_feature)``: the per-step samples, the per-step
        predicted clean samples, and the per-step features. Entries produced by
        the loop are on CPU; ``xs[0]`` and ``xs_feature[0]`` stay on the model
        device.
    """
    with torch.no_grad():
        device = next(model.parameters()).device
        # Keep every operand on the model's device so sampling also works on CPU
        # or on a non-default GPU (no hard-coded 'cuda').
        b = b.to(device)
        prototype = prototype.to(device)

        z = torch.normal(0., std=variance).to(device)
        x = z
        x_feature = torch.matmul(z, prototype)
        n = x.size(0)
        n_feature = x_feature.size(0)
        # Pair each timestep with its predecessor. compute_alpha maps -1 to
        # alpha_bar = 1, i.e. the final fully-denoised target.
        seq_next = [-1] + list(seq[:-1])
        x0_preds = []
        xs = [x]
        xs_feature = [x_feature]
        for i, j in zip(reversed(seq), reversed(seq_next)):
            t = (torch.ones(n, device=device) * i).long()
            next_t = (torch.ones(n, device=device) * j).long()
            t_feature = (torch.ones(n_feature, device=device) * i).long()
            next_t_feature = (torch.ones(n_feature, device=device) * j).long()
            at = compute_alpha(b, t)
            at_next = compute_alpha(b, next_t)
            at_feature = compute_alpha(b, t_feature)
            at_next_feature = compute_alpha(b, next_t_feature)

            xt = xs[-1].to(device)
            xt_feature = xs_feature[-1].to(device)
            et = model(xt_feature, t)['noise']
            half = et.shape[1] // 2
            et_x = et[:, :half]
            # Feature-space noise estimate: second half of the output, projected.
            et_feature = torch.matmul(et[:, half:], prototype)

            # Predicted clean sample / feature (x_0) from the current noisy state.
            x0_t = (xt - et_x * (1 - at).sqrt()) / at.sqrt()
            x0_t_feature = (xt_feature - et_feature * (1 - at_feature).sqrt()) / at_feature.sqrt()
            x0_preds.append(x0_t.to('cpu'))

            c1, c2 = _ddim_noise_scales(at, at_next, eta)
            c1_feature, c2_feature = _ddim_noise_scales(at_feature, at_next_feature, eta)

            # NOTE(review): the feature noise is an independent draw rather than the
            # projection of the sample noise; confirm this is intended.
            # The draw order (sample first, then feature) is kept stable for seeded runs.
            noise = torch.normal(0., std=variance).to(device)
            xt_next = at_next.sqrt() * x0_t + c1 * noise + c2 * et_x
            noise_feature = torch.matmul(torch.normal(0., std=variance).to(device), prototype)
            xt_next_feature = (
                at_next_feature.sqrt() * x0_t_feature
                + c1_feature * noise_feature
                + c2_feature * et_feature
            )
            # Normalize each row of the sample into a probability distribution.
            xt_next = F.softmax(xt_next, dim=1)
            xs.append(xt_next.to('cpu'))
            xs_feature.append(xt_next_feature.to('cpu'))
    if last:
        return xs[-1], xs_feature[-1]
    return xs, x0_preds, xs_feature
            
if __name__ == "__main__":
    # Smoke check: print a random 4x8 tensor.
    demo_batch = torch.rand(4, 8)
    print(demo_batch)